# User-facing switches for the training tools (both default to OFF).
# MNN_TRAIN_DEBUG: when ON, this file adds the MNN_TRAIN_DEBUG compile definition.
option(MNN_TRAIN_DEBUG "Enable MNN Train Grad Debug" OFF)
# MNN_USE_OPENCV: when ON, this file requires OpenCV, adds the MNN_USE_OPENCV
# compile definition and links OpenCV into runTrainDemo.out.
option(MNN_USE_OPENCV "Use opencv" OFF)

# Header search paths for every training component, added at directory scope
# in the same order as before.
include_directories(
    ${CMAKE_CURRENT_LIST_DIR}/source/grad
    ${CMAKE_CURRENT_LIST_DIR}/source/optimizer
    ${CMAKE_CURRENT_LIST_DIR}/source/transformer
    ${CMAKE_CURRENT_LIST_DIR}/source/data
    ${CMAKE_CURRENT_LIST_DIR}/source/nn
    ${CMAKE_CURRENT_LIST_DIR}/source/models
    ${CMAKE_CURRENT_LIST_DIR}/source/datasets
)
# Runtime artifacts are emitted two levels above the current binary directory.
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/../../)
# Gather the contents of each training source directory. The pattern is `*`,
# so every entry in the directory is collected, not only .cpp files.
set(MNN_TRAIN_SOURCE_ROOT ${CMAKE_CURRENT_LIST_DIR}/source)
file(GLOB GRAD        ${MNN_TRAIN_SOURCE_ROOT}/grad/*)
file(GLOB TRANSFORMER ${MNN_TRAIN_SOURCE_ROOT}/transformer/*)
file(GLOB OPTIMIZER   ${MNN_TRAIN_SOURCE_ROOT}/optimizer/*)
file(GLOB DATALOADER  ${MNN_TRAIN_SOURCE_ROOT}/data/*)
file(GLOB MODELS      ${MNN_TRAIN_SOURCE_ROOT}/models/*)
file(GLOB NNFILES     ${MNN_TRAIN_SOURCE_ROOT}/nn/*)
file(GLOB DATASETS    ${MNN_TRAIN_SOURCE_ROOT}/datasets/*)

# MNNTrain / MNNTrainUtils libraries
if(MNN_TRAIN_DEBUG)
    # Directory-scoped definition: applies to the targets of this directory.
    add_definitions(-DMNN_TRAIN_DEBUG)
endif()

# Core training sources. BASIC_INCLUDE is not defined in this file; it is
# presumably provided by the including project (expands to nothing otherwise).
set(MNN_TRAIN_SRCS ${GRAD} ${BASIC_INCLUDE} ${OPTIMIZER} ${TRANSFORMER} ${NNFILES})
# Helper sources: models, datasets and data loading.
set(MNN_TRAIN_UTILS_SRCS ${MODELS} ${DATASETS} ${DATALOADER})

# A separate build produces shared libraries; otherwise only object libraries
# are created here.
if(MNN_SEP_BUILD)
    set(MNN_TRAIN_LIBRARY_KIND SHARED)
else()
    set(MNN_TRAIN_LIBRARY_KIND OBJECT)
endif()
add_library(MNNTrain ${MNN_TRAIN_LIBRARY_KIND} ${MNN_TRAIN_SRCS})
add_library(MNNTrainUtils ${MNN_TRAIN_LIBRARY_KIND} ${MNN_TRAIN_UTILS_SRCS})

if(MNN_SEP_BUILD)
    # Only the shared-library flavour is linked against its dependencies here.
    target_link_libraries(MNNTrain MNN MNN_Express)
    target_link_libraries(MNNTrainUtils MNNTrain)
endif()

# stb_image configuration macros, applied to MNNTrainUtils only.
target_compile_definitions(MNNTrainUtils
    PRIVATE
        STB_IMAGE_STATIC
        STB_IMAGE_IMPLEMENTATION
)

# executables
# Command-line tools built from this directory. Their target names are
# collected in MNN_TRAIN_TOOLS so that later code can configure them together.
set(MNN_TRAIN_TOOLS "")
add_executable(transformer ${CMAKE_CURRENT_LIST_DIR}/source/exec/transformerExecution.cpp)
add_executable(extractForInfer ${CMAKE_CURRENT_LIST_DIR}/source/exec/extractForInfer.cpp)

# The demo runner is built from every entry of source/demo (pattern `*`).
# BASIC_INCLUDE is not defined in this file; presumably supplied by the
# including project (expands to nothing otherwise).
file(GLOB DEMOSOURCE ${CMAKE_CURRENT_LIST_DIR}/source/demo/*)
add_executable(runTrainDemo.out ${DEMOSOURCE} ${BASIC_INCLUDE})
# Anchor the include path to this file's directory. A bare relative path is
# resolved against CMAKE_CURRENT_SOURCE_DIR, which differs from
# CMAKE_CURRENT_LIST_DIR when this file is pulled in through include().
target_include_directories(runTrainDemo.out
    PRIVATE
        ${CMAKE_CURRENT_LIST_DIR}/../../3rd_party/imageHelper/
)

list(APPEND MNN_TRAIN_TOOLS
    transformer
    extractForInfer
    runTrainDemo.out
)
# Apply the common link and compile settings to every tool executable.
foreach(train_tool ${MNN_TRAIN_TOOLS})
    target_link_libraries(${train_tool} PRIVATE ${MNN_DEPS})

    if(MNN_SEP_BUILD)
        # The train libraries are standalone shared libraries in this mode,
        # so the tools have to link them explicitly.
        target_link_libraries(${train_tool} PRIVATE MNNTrain MNNTrainUtils)
    endif()

    if(MSVC)
        target_compile_definitions(${train_tool} PRIVATE _CRT_SECURE_NO_WARNINGS)
    endif()

    if(MSVC AND NOT MNN_BUILD_SHARED_LIBS)
        # Static MSVC build: pass /WHOLEARCHIVE for each dependency's target
        # file so the linker keeps all of its objects.
        foreach(train_dep ${MNN_DEPS})
            target_link_options(${train_tool} PRIVATE /WHOLEARCHIVE:$<TARGET_FILE:${train_dep}>)
        endforeach()
    endif()
endforeach()

# Optional OpenCV support for the demo runner.
if(MNN_USE_OPENCV)
    # Replace the C and C++ compiler flags with the *_FLAGS_ORIGIN values.
    # NOTE(review): those variables are not set in this file; presumably they
    # hold the flags captured before the project modified them — confirm.
    set(CMAKE_C_FLAGS ${CMAKE_C_FLAGS_ORIGIN})
    set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS_ORIGIN})

    # Configuration fails here if OpenCV cannot be located.
    find_package(OpenCV REQUIRED)

    # Directory-scoped definition; only runTrainDemo.out receives the OpenCV
    # include paths and libraries.
    add_definitions(-DMNN_USE_OPENCV)
    target_include_directories(runTrainDemo.out PRIVATE ${OpenCV_INCLUDE_DIRS})
    target_link_libraries(runTrainDemo.out PRIVATE ${OpenCV_LIBS})
endif()
